{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 判断是否支持GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 让模型支持GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# network、data、label要成为cuda类型\n",
    "def TensorAndCuda(img, label):\n",
    "    img = torch.Tensor(img)\n",
    "    label = torch.Tensor(label)    \n",
    "    if torch.cuda.is_available():\n",
    "        img = img.cuda()\n",
    "        label = label.cuda()\n",
    "        network = network.cuda()\n",
    "    else:\n",
    "        img = Variable(img)\n",
    "        label = Variable(label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 清除GPU的缓存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPU运行结果计算准确率"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "out = network(img)\n",
    "_, predicted = torch.max(out, 1)\n",
    "cpu_info = predicted.cpu()\n",
    "correct = (cpu_info == torch.Tensor(label[0])).sum().item()\n",
    "accuracy = correct / data[0].shape[0]\n",
    "print('validate accuracy: {:.4}'.format(accuracy))\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
